import numpy as np
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

# Image preprocessing for the MNIST datasets below. NOTE(review): torchvision
# applies this transform only when samples are fetched by indexing the dataset;
# the code visible in this file reads the raw ``.data`` / ``.targets`` tensors
# directly and rescales pixels by 1/255 itself, so the transform is not used there.
transform = transforms.Compose([
    transforms.ToTensor()  # Convert each image to a (1, 28, 28) float tensor
])

# Both MNIST splits are created at import time under ./data; download=True lets
# torchvision fetch them if they are not already there (module-level side effect).
# NOTE(review): "MNIST_testet" is a misspelling of "testset"; it is kept because
# other code refers to this exact name.
MNIST_trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
MNIST_testet = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)


def get_trainset(trainset, label=0):
    """Return the images and targets of every sample in *trainset* of class *label*.

    Args:
        trainset: dataset exposing indexable ``data`` and ``targets``
            attributes (e.g. a torchvision MNIST dataset). ``targets`` may be
            a tensor or any sequence accepted by ``torch.as_tensor``.
        label: class id to select.

    Returns:
        ``(X, Y)`` where ``X`` are the rows of ``trainset.data`` whose target
        equals *label* and ``Y`` the matching targets, in original order.
        Both are empty when no sample has that label.
    """
    targets = torch.as_tensor(trainset.targets)
    # A single vectorized comparison replaces a Python loop that built one
    # 0-d tensor per sample (60k of them for the MNIST training set).
    mask = targets == label
    return trainset.data[mask], targets[mask]

# Per-digit training subsets, built once at import time:
#   CLASS_DATA[d]        -> (images, targets) of every training sample of digit d
#   CLASS_DATA_LENGTH[d] -> number of such images
# Comprehensions avoid leaking a loop variable into the module namespace.
CLASS_DATA = {digit: get_trainset(MNIST_trainset, label=digit) for digit in range(10)}
CLASS_DATA_LENGTH = {digit: len(images) for digit, (images, _) in CLASS_DATA.items()}

class Net(nn.Module):
    """Small two-conv / two-FC CNN producing 10-way log-probabilities.

    The attribute names (conv1, conv2, fc1, fc2) determine the state_dict
    keys, and the layer creation order determines how the default
    initializers consume the RNG, so both are kept stable for saved weights.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # 320 = 20 channels * 4 * 4: a 28x28 input becomes 24 -> 12 -> 8 -> 4
        # after conv(5), pool(2), conv(5), pool(2).
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return log-softmax scores of shape (batch, 10).

        Expects ``x`` of shape (batch, 1, 28, 28); other spatial sizes make
        ``fc1`` raise a shape error.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        # Flatten per sample. Unlike ``view(-1, 320)`` this never silently
        # redistributes features across the batch dimension for mis-sized
        # inputs, and it also works on non-contiguous tensors.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def class_mapping(number_batch, num_classes=10):
    """Map values in [0, 1] to integer class ids ``floor(v * num_classes)``.

    Args:
        number_batch: array-like of values expected in [0, 1]; any shape.
        num_classes: number of classes (default 10, the MNIST digits).

    Returns:
        1-D integer numpy array with one class id per input value. A value of
        exactly 1.0 would otherwise map to the out-of-range id
        ``num_classes``, so results are capped at ``num_classes - 1``.
    """
    scaled = np.floor(np.asarray(number_batch) * num_classes)
    return np.minimum(scaled, num_classes - 1).flatten().astype(int)


class MNIST:
    """Trainer whose mini-batch composition is driven by an external class stream.

    Each ``step`` pops a batch of values from ``class_stream``, maps them to
    digit classes with ``class_mapping`` and trains ``Net`` on one randomly
    chosen training image per mapped value.
    """

    def __init__(self, class_stream, learning_rate, device="cuda:0"):
        """
        Args:
            class_stream: object whose ``pop()`` returns a batch of values
                expected in [0, 1] (fed to ``class_mapping``), or ``None``
                when exhausted.
            learning_rate: SGD learning rate.
            device: torch device holding the network and batches.
        """
        self.class_stream = class_stream
        self.learning_rate = learning_rate
        self.device = device

        self.net = Net()
        # Reuse the weights stored in ini.pth when present; otherwise persist
        # this net's fresh weights there so later instances start identically.
        # Only a missing file triggers (re)creation: other load errors must not
        # silently overwrite the stored initialization.
        try:
            self.net.load_state_dict(torch.load("ini.pth", map_location="cpu"), strict=False)
        except FileNotFoundError:
            torch.save(self.net.state_dict(), "ini.pth")
        self.net.to(self.device)

        self.optim = torch.optim.SGD(self.net.parameters(), lr=learning_rate)
        self.loss_hist = []
        self.acc_hist = []

    def _loss(self, data):
        raise NotImplementedError

    def _prepare_images(self, data):
        """Scale raw pixel data (presumably 0..255) to [0, 1], add a channel
        dimension at axis 1 and move it to ``self.device``."""
        return torch.unsqueeze(data / 255.0, dim=1).to(self.device)

    def step(self):
        """Run one SGD update on a batch whose classes come from ``class_stream``.

        Raises:
            RuntimeError: if ``class_stream.pop()`` returns ``None``.
        """
        labels = self.class_stream.pop()
        if labels is None:
            raise RuntimeError("class_stream is exhausted: pop() returned None")

        # Build the batch: one random training example per mapped class.
        images = []
        targets = []
        for class_ in class_mapping(labels):
            ind = np.random.randint(CLASS_DATA_LENGTH[class_], size=1)
            images.append(CLASS_DATA[class_][0][ind])
            targets.append(CLASS_DATA[class_][1][ind])

        batch_image = self._prepare_images(torch.cat(images, 0))
        batch_label = torch.cat(targets, 0).to(self.device)

        # SGD update over the mini-batch.
        loss = F.nll_loss(self.net(batch_image), batch_label)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

    def eval(self, eval_loss = False):
        """Append test-set accuracy to ``acc_hist``; if *eval_loss*, also
        append the mean NLL loss over the full training set to ``loss_hist``."""
        # Evaluation needs no autograd graph; without no_grad the
        # whole-dataset forward passes would keep every activation alive.
        with torch.no_grad():
            label = MNIST_testet.targets.to(self.device)
            outputs = self.net(self._prepare_images(MNIST_testet.data))
            predicted = outputs.argmax(dim=1)
            acc = (predicted == label).sum().item() / label.size(0)
            self.acc_hist.append(acc)

            if eval_loss:
                label = MNIST_trainset.targets.to(self.device)
                image = self._prepare_images(MNIST_trainset.data)
                loss = F.nll_loss(self.net(image), label)
                self.loss_hist.append(loss.item())



class MNIST2:
    """Trainer that draws ordinary mini-batches from the MNIST training set.

    With ``sampling="epoch"`` the trainer walks a random permutation of the
    training indices (no replacement) and reshuffles once fewer than
    ``batch_size`` indices remain (the short remainder is skipped). Any other
    ``sampling`` value draws each pass of indices with replacement.
    """

    def __init__(self, batch_size, learning_rate, sampling="epoch", device="cuda:0"):
        """
        Args:
            batch_size: number of training samples per SGD step.
            learning_rate: SGD learning rate.
            sampling: "epoch" for without-replacement passes; anything else
                samples with replacement.
            device: torch device holding the network and batches.
        """
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.sampling = sampling
        self.device = device
        self.indices = self._draw_indices()
        self.count = 0  # position of the next batch within self.indices

        self.net = Net()
        # Reuse the weights stored in ini.pth when present; otherwise persist
        # this net's fresh weights there so later instances start identically.
        # Only a missing file triggers (re)creation.
        try:
            self.net.load_state_dict(torch.load("ini.pth", map_location="cpu"), strict=False)
        except FileNotFoundError:
            torch.save(self.net.state_dict(), "ini.pth")
        self.net.to(self.device)

        self.optim = torch.optim.SGD(self.net.parameters(), lr=learning_rate)
        self.loss_hist = []
        self.acc_hist = []

    def _loss(self, data):
        raise NotImplementedError

    def _draw_indices(self):
        """Draw one pass of training-set indices according to ``self.sampling``."""
        n = len(MNIST_trainset)
        # np.random.choice(n, ...) consumes the global RNG exactly like the
        # previous list-based call (permutation without, randint with replacement).
        return np.random.choice(n, size=n, replace=self.sampling != "epoch")

    def _prepare_images(self, data):
        """Scale raw pixel data (presumably 0..255) to [0, 1], add a channel
        dimension at axis 1 and move it to ``self.device``."""
        return torch.unsqueeze(data / 255.0, dim=1).to(self.device)

    def step(self):
        """Run one SGD update on the next mini-batch of training indices."""
        if self.count + self.batch_size > len(self.indices):
            # Not enough indices left for a full batch: start a new pass.
            self.indices = self._draw_indices()
            self.count = 0
        ind = self.indices[self.count:self.count + self.batch_size]
        self.count += self.batch_size

        batch_image = self._prepare_images(MNIST_trainset.data[ind])
        batch_label = MNIST_trainset.targets[ind].to(self.device)

        # SGD update over the mini-batch; clear gradients from the last step
        # first so they do not accumulate.
        loss = F.nll_loss(self.net(batch_image), batch_label)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

    def eval(self, eval_loss = False):
        """Append test-set accuracy to ``acc_hist``; if *eval_loss*, also
        append the mean NLL loss over the full training set to ``loss_hist``."""
        with torch.no_grad():
            label = MNIST_testet.targets.to(self.device)
            outputs = self.net(self._prepare_images(MNIST_testet.data))
            predicted = outputs.argmax(dim=1)
            acc = (predicted == label).sum().item() / label.size(0)
            self.acc_hist.append(acc)

            if eval_loss:
                # Inputs must live on the same device as the network.
                label = MNIST_trainset.targets.to(self.device)
                image = self._prepare_images(MNIST_trainset.data)
                loss = F.nll_loss(self.net(image), label)
                self.loss_hist.append(loss.item())

